import glob
import os
import random

import cv2
import h5py
import numpy as np
import matplotlib.pyplot as plt

# Percentage of samples routed to the validation split (see validate_PERCENTAGE below)
from h5 import create_h5, get_dataset

# Percentage of samples used for validation (0 disables the validation split).
validate_PERCENTAGE = 0
# Percentage of samples used for testing.
TEST_PERCENTAGE = 20
# Minimum image width/height; images smaller than this are discarded
# and not added to the dataset (enforcement currently commented out below).
MIN_IMAGE_WIDTH = 300
MIN_IMAGE_HEIGHT = 300
# Target image size (Inception-style 299x299).
IM_WIDTH, IM_HEIGHT = 299, 299


# 返回一个目录path所属的label标签
# Return the label (top-level directory name) that a directory path belongs to.
def belong_label(path, input_data, labels=None):
    """Return the label whose directory contains *path*, or None.

    Args:
        path: directory path to classify (e.g. a sub-directory from os.walk).
        input_data: root directory of the dataset.
        labels: label names, each a direct sub-directory of input_data.

    Returns:
        The matching label name, or None when *path* is under none of them.
    """
    if labels is None:  # avoid mutable default argument
        labels = []
    for label in labels:
        label_path = os.path.join(input_data, label)
        # Compare on a directory boundary: a bare startswith() would wrongly
        # match sibling directories sharing a prefix (e.g. 'cats' vs 'catsuit').
        if path == label_path or path.startswith(label_path + os.sep):
            return label
    return None


# 把样本中所有的图片列表并按训练、验证、测试数据分开
# List every image under input_data and split it into train/test/validate sets.
def create_image_lists(input_data, labels=None, sample_num=120):
    """Walk *input_data* and randomly split its images per label.

    The direct sub-directories of *input_data* define the label set; the
    *labels* argument is kept for backward compatibility but is overwritten
    by the directories actually found on disk.  Each file is assigned to a
    split by drawing a random integer in [0, 100) against the module-level
    validate_PERCENTAGE / TEST_PERCENTAGE constants.

    Args:
        input_data: root directory; one sub-directory per label.
        labels: ignored (recomputed from disk); kept for interface stability.
        sample_num: currently unused (historical cap on train-set size).

    Returns:
        dict mapping label -> {'train': [...], 'test': [...], 'validate': [...]}
        where each list holds absolute/relative file paths.
    """
    if labels is None:  # avoid mutable default argument
        labels = []
    result = {}

    is_root_dir = True
    for sub_dir, dirs, files in os.walk(input_data):
        # The first directory yielded is input_data itself: its sub-directories
        # are the labels; no files are collected from the root.
        if is_root_dir:
            labels = dirs
            for label in labels:
                result[label] = {
                    'train': [],
                    'test': [],
                    'validate': []
                }
            is_root_dir = False
            continue
        if not files:
            continue

        label_name = belong_label(sub_dir, input_data, labels)
        if label_name is None:
            continue

        for file_name in files:
            file_path = os.path.join(sub_dir, file_name)
            # NOTE(review): minimum-size filtering (MIN_IMAGE_WIDTH/HEIGHT)
            # is intentionally disabled here.
            chance = np.random.randint(100)
            if chance < validate_PERCENTAGE:
                result[label_name]['validate'].append(file_path)
            elif chance < TEST_PERCENTAGE + validate_PERCENTAGE:
                result[label_name]['test'].append(file_path)
            else:
                result[label_name]['train'].append(file_path)

    # Shuffle in place so downstream batching does not see directory order.
    for label in labels:
        random.shuffle(result[label]['train'])
        random.shuffle(result[label]['test'])
    return result


def create_dataset(input_data, labels=None):
    """Build 'train' and 'test' HDF5 datasets from the images under *input_data*.

    Args:
        input_data: root directory; one sub-directory per label.
        labels: optional label list forwarded to create_image_lists
            (which recomputes it from disk anyway).
    """
    if labels is None:  # avoid mutable default argument
        labels = []
    image_list = create_image_lists(input_data, labels)
    create_h5(image_list, input_data, 'train')
    create_h5(image_list, input_data, 'test')
    # Validation split currently disabled (validate_PERCENTAGE == 0):
    # create_h5(image_list, input_data, 'validate')


def load_dataset(base_dir):
    """Load the train/test datasets stored as HDF5 files under *base_dir*.

    Expects 'train.h5' and 'test.h5' to exist in *base_dir*, each holding
    '<split>_set_x' (features) and '<split>_set_y' (labels) datasets.

    Returns:
        (train_x, train_y, test_x, test_y) as numpy arrays.
    """
    print('load dataset start...')

    print('load train dataset start...')
    train_h5 = get_dataset(os.path.join(base_dir, 'train.h5'))
    train_x = np.array(train_h5["train_set_x"][:])  # training features
    train_y = np.array(train_h5["train_set_y"][:])  # training labels
    print('load train dataset end')

    print('load test dataset start...')
    test_h5 = get_dataset(os.path.join(base_dir, 'test.h5'))
    test_x = np.array(test_h5["test_set_x"][:])  # test features
    test_y = np.array(test_h5["test_set_y"][:])  # test labels
    print('load test dataset end')

    print('load dataset end')
    return train_x, train_y, test_x, test_y


# --- Script entry area ---------------------------------------------------
# Hard-coded dataset root (Windows-style path); edit before running.
# input_data = 'D:/cdn/filtered'

input_data = 'E:/cdn/comment'
# Example: build the split lists and print per-label training counts.
# result = create_image_lists(input_data, ['useful', 'useless'])
# for label in result.keys():
#     list = result[label]['train']
#     num = len(list)
#     print('label:' + label + ',num is:' + str(num))
# create_dataset(input_data, ['useful', 'useless'])

# Example: build the HDF5 datasets with labels discovered from disk.
# create_dataset(input_data)

# print(X_train.shape, y_train.shape)
# print(y_train)

# Example: reload the datasets and preview one training image.
# train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig = load_dataset(input_data)
# print(train_set_x_orig.shape,train_set_y_orig.shape,test_set_x_orig.shape,test_set_y_orig.shape)
#
# index = 6
# plt.imshow(train_set_x_orig[index])
# print("y = " + str(np.squeeze(train_set_y_orig[index])))
# plt.show()
